{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Converting Exact GP Models to TorchScript\n",
    "\n",
    "In this notebook, we'll demonstrate converting an Exact GP model to TorchScript. In general, this is the same as for standard PyTorch models where we'll use `torch.jit.trace`, but there are two pecularities to keep in mind for GPyTorch:\n",
    "\n",
    "1. The first time you make predictions with a GPyTorch model (exact or approximate), we cache certain computations. These computations can't be traced, but the results of them can be. Therefore, we'll need to pass data through the untraced model once, and then trace the model.\n",
    "1. For exact GPs, we can't trace models unless `gpytorch.settings.fast_pred_var` is used. This is a technical issue that may not be possible to overcome due to limitations on what can be traced in PyTorch; however, if you really need to trace a GP but can't use the above setting, open an issue so we have visibility on there being demand for this.\n",
    "1. You can't trace models that return Distribution objects. Therefore, we'll write a simple wrapper than unpacks the MultivariateNormal that our GPs return in to just a mean and variance tensor."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define and train an exact GP\n",
    "\n",
    "In the next cell, we define some data, define a GP model and train it. Nothing new here -- pretty much just move on to the next cell after this one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
   ],
   "source": [
    "import math\n",
    "import torch\n",
    "import gpytorch\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# Training data is 100 points in [0,1] inclusive regularly spaced\n",
    "train_x = torch.linspace(0, 1, 100)\n",
    "# True function is sin(2*pi*x) with Gaussian noise\n",
    "train_y = torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2\n",
    "\n",
    "# We will use the simplest form of GP model, exact inference\n",
    "class ExactGPModel(gpytorch.models.ExactGP):\n",
    "    def __init__(self, train_x, train_y, likelihood):\n",
    "        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)\n",
    "        self.mean_module = gpytorch.means.ConstantMean()\n",
    "        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())\n",
    "    \n",
    "    def forward(self, x):\n",
    "        mean_x = self.mean_module(x)\n",
    "        covar_x = self.covar_module(x)\n",
    "        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)\n",
    "\n",
    "# initialize likelihood and model\n",
    "likelihood = gpytorch.likelihoods.GaussianLikelihood()\n",
    "model = ExactGPModel(train_x, train_y, likelihood)\n",
    "\n",
    "# this is for running the notebook in our testing framework\n",
    "import os\n",
    "smoke_test = ('CI' in os.environ)\n",
    "training_iter = 2 if smoke_test else 50\n",
    "\n",
    "\n",
    "# Find optimal model hyperparameters\n",
    "model.train()\n",
    "likelihood.train()\n",
    "\n",
    "# Use the adam optimizer\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.1)  # Includes GaussianLikelihood parameters\n",
    "\n",
    "# \"Loss\" for GPs - the marginal log likelihood\n",
    "mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)\n",
    "\n",
    "for i in range(training_iter):\n",
    "    # Zero gradients from previous iteration\n",
    "    optimizer.zero_grad()\n",
    "    # Output from model\n",
    "    output = model(train_x)\n",
    "    # Calc loss and backprop gradients\n",
    "    loss = -mll(output, train_y)\n",
    "    loss.backward()\n",
    "    optimizer.step()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Trace the Model\n",
    "\n",
    "In the next cell, we trace our GP model. To overcome the fact that we can't trace Modules that return Distributions, we write a wrapper Module that unpacks the GP output in to a mean and variance.\n",
    "\n",
    "\n",
    "Additionally, we'll need to run with the `gpytorch.settings.trace_mode` setting enabled, because PyTorch can't trace custom autograd Functions. Note that this results in some inefficiencies.\n",
    "\n",
    "Then, before calling `torch.jit.trace` we first call the model on `test_x`. This step is **required**, as it does some precomputation using torch functionality that cannot be traced."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MeanVarModelWrapper(torch.nn.Module):\n",
    "    def __init__(self, gp):\n",
    "        super().__init__()\n",
    "        self.gp = gp\n",
    "    \n",
    "    def forward(self, x):\n",
    "        output_dist = self.gp(x)\n",
    "        return output_dist.mean, output_dist.variance\n",
    "\n",
    "with torch.no_grad(), gpytorch.settings.fast_pred_var(), gpytorch.settings.trace_mode():\n",
    "    model.eval()\n",
    "    test_x = torch.linspace(0, 1, 51)\n",
    "    pred = model(test_x)  # Do precomputation\n",
    "    traced_model = torch.jit.trace(MeanVarModelWrapper(model), test_x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compare Predictions from TorchScript model and Torch model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.)\n",
      "tensor(0.)\n"
     ]
    }
   ],
   "source": [
    "with torch.no_grad():\n",
    "    traced_mean, traced_var = traced_model(test_x)\n",
    "\n",
    "print(torch.norm(traced_mean - pred.mean))\n",
    "print(torch.norm(traced_var - pred.variance))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "traced_model.save('traced_exact_gp.pt')"
   ]
  },
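  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check on the file saved above, the following sketch reloads it with `torch.jit.load` and compares its output against the earlier predictions. It assumes that `traced_exact_gp.pt` was written by the previous cell and that `test_x` and `pred` are still in scope; the names `loaded_model`, `loaded_mean`, and `loaded_var` are introduced here purely for illustration. Only the inputs used for tracing are checked, since this notebook does not establish how the traced model behaves on inputs of a different shape."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch only: assumes 'traced_exact_gp.pt' exists and test_x / pred are defined above\n",
    "loaded_model = torch.jit.load('traced_exact_gp.pt')\n",
    "\n",
    "with torch.no_grad():\n",
    "    # The traced wrapper returns a (mean, variance) tuple of tensors\n",
    "    loaded_mean, loaded_var = loaded_model(test_x)\n",
    "\n",
    "# Distance between the reloaded model's outputs and the untraced predictions\n",
    "print(torch.norm(loaded_mean - pred.mean))\n",
    "print(torch.norm(loaded_var - pred.variance))"
   ]
  },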
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
